{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "# Load the table and derive each row's video id, i.e. the second-to-last\n",
    "# \"/\"-separated component of its \"File\" path.\n",
    "df = pd.read_csv(\"../data/voxceleb2_train_ps.csv\")\n",
    "df[\"Video\"] = df[\"File\"].map(lambda path: path.split(\"/\")[-2])\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average number of distinct videos per speaker.\n",
    "unique_videos_per_speaker = (\n",
    "    df.groupby(\"Speaker\")[\"Video\"]\n",
    "    .nunique()\n",
    "    .reset_index()\n",
    ")\n",
    "unique_videos_per_speaker[\"Video\"].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Element 0 of the loaded object is used as the centroid matrix; `sim` holds\n",
    "# the dot product between every pair of centroids.\n",
    "centroids = torch.load(\"../our_centroids.pt\", map_location=\"cpu\")[0]\n",
    "sim = torch.matmul(centroids, centroids.T)\n",
    "\n",
    "sim.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Videos in which more than two distinct speakers appear.\n",
    "tmp = (\n",
    "    df.groupby(\"Video\")[\"Speaker\"]\n",
    "    .nunique()\n",
    "    .reset_index()\n",
    ")\n",
    "tmp.loc[tmp[\"Speaker\"] > 2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map every cluster id (\"Speaker_ps\") to the speaker label that occurs most\n",
    "# often among its rows. One groupby pass replaces re-filtering the whole\n",
    "# frame once per cluster; sort=False keeps first-appearance key order.\n",
    "cluster_to_spk = (\n",
    "    df.groupby(\"Speaker_ps\", sort=False)[\"Speaker\"]\n",
    "    .agg(lambda speakers: speakers.value_counts().idxmax())\n",
    "    .to_dict()\n",
    ")\n",
    "\n",
    "len(cluster_to_spk)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect centroid similarities between each of the first 100 clusters and\n",
    "# every other cluster, split by whether both clusters map to the same\n",
    "# majority speaker (positives) or not (negatives).\n",
    "pos_scores = []\n",
    "neg_scores = []\n",
    "clusters = df[\"Speaker_ps\"].unique().tolist()\n",
    "\n",
    "sim_np = sim.cpu().numpy()\n",
    "\n",
    "for cluster in clusters[:100]:\n",
    "    anchor_spk = cluster_to_spk[cluster]\n",
    "    for c in clusters:\n",
    "        if c == cluster:\n",
    "            # A cluster compared with itself is not a pair; counting its\n",
    "            # self-similarity would bias the positive distribution.\n",
    "            continue\n",
    "        if cluster_to_spk[c] == anchor_spk:\n",
    "            pos_scores.append(sim_np[cluster, c])\n",
    "        else:\n",
    "            neg_scores.append(sim_np[cluster, c])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.stats import gaussian_kde\n",
    "\n",
    "# Turn similarities into distances under new names: rebinding pos_scores /\n",
    "# neg_scores here would make a second run of this cell compute 1 - (1 - s)\n",
    "# and silently flip the values back to similarities.\n",
    "pos_dists = 1 - np.asarray(pos_scores)\n",
    "neg_dists = 1 - np.asarray(neg_scores)\n",
    "\n",
    "# Shared evaluation grid spanning the range of both sets of distances.\n",
    "x = np.linspace(\n",
    "    min(pos_dists.min(), neg_dists.min()),\n",
    "    max(pos_dists.max(), neg_dists.max()),\n",
    "    1000,\n",
    ")\n",
    "y_pos = gaussian_kde(pos_dists)(x)\n",
    "y_neg = gaussian_kde(neg_dists)(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overlay the two density curves on a shared distance axis.\n",
    "fig, ax = plt.subplots(figsize=(12, 6))\n",
    "\n",
    "ax.plot(x, y_pos, label=\"Positives\", color=\"green\")\n",
    "ax.plot(x, y_neg, label=\"Negatives\", color=\"red\")\n",
    "\n",
    "ax.set_title(\"Distributions of distances\")\n",
    "ax.set_xlabel(\"Distance\")\n",
    "ax.set_ylabel(\"Density\")\n",
    "# ax.set_yscale(\"log\")\n",
    "# ax.set_xscale(\"log\")\n",
    "ax.legend()\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simulation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "# Load the embeddings and L2-normalise them along dim 1 (the defaults of\n",
    "# F.normalize, spelled out here).\n",
    "embeddings = torch.load(\"../embeddings_100_full_vox2.pt\", map_location=\"cpu\")\n",
    "embeddings = F.normalize(embeddings, p=2.0, dim=1)\n",
    "embeddings.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "def simulation(\n",
    "    video_threshold = 0.5,\n",
    "    prob_decay = 0.2,\n",
    "    speaker_threshold = 0.9,\n",
    "    count = 5000,\n",
    "    verbose = False,\n",
    "    seed = 0\n",
    "):\n",
    "    \"\"\"Simulate picking a second sample for randomly drawn rows of `df`.\n",
    "\n",
    "    For each drawn row, one other cluster is sampled among the clusters whose\n",
    "    centroid similarity to the row's cluster exceeds `video_threshold`; then\n",
    "    one sample of that cluster is drawn among those whose similarity to the\n",
    "    selected centroid exceeds `speaker_threshold`. Both draws rank the\n",
    "    candidates by decreasing similarity and weight rank k by `prob_decay ** k`.\n",
    "\n",
    "    Relies on the notebook globals `df`, `sim`, `centroids`, `embeddings`\n",
    "    and `np`.\n",
    "\n",
    "    Args:\n",
    "        video_threshold: Minimum centroid-to-centroid similarity for a\n",
    "            cluster to be a candidate.\n",
    "        prob_decay: Base of the geometric rank weights.\n",
    "        speaker_threshold: Minimum sample-to-centroid similarity for a\n",
    "            sample to be a candidate.\n",
    "        count: Number of rows drawn (with replacement) from `df`.\n",
    "        verbose: Print per-row diagnostics.\n",
    "        seed: Seed applied to both the NumPy and the PyTorch RNG.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(speaker_acc, video_acc, coverage)`: among the rows for which\n",
    "        a sample was picked, the fraction whose pick has the same \"Speaker\"\n",
    "        and the fraction whose pick has the same \"Video\" as the drawn row;\n",
    "        and the fraction of drawn rows for which a sample was picked.\n",
    "    \"\"\"\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "\n",
    "    # Picked file per drawn row (None when no candidate exists); only the\n",
    "    # commented-out `df[\"File2\"]` assignment below would consume it.\n",
    "    res = []\n",
    "\n",
    "    count_total = 0\n",
    "    speaker_acc_total = 0\n",
    "    video_acc_total = 0\n",
    "\n",
    "    # for i in tqdm(range(len(df))):\n",
    "    for i in tqdm(np.random.randint(0, len(df), size=(count,))):\n",
    "        row = df.iloc[i]  # one positional lookup instead of three\n",
    "        speaker = row[\"Speaker\"]\n",
    "        video = row[\"Video\"]\n",
    "        cluster = int(row[\"Speaker_ps\"])\n",
    "        if verbose:\n",
    "            print(f\"Current sample: {i}, Speaker: {speaker}, Video: {video}\")\n",
    "\n",
    "        # Determine nearby clusters. The query cluster is removed explicitly\n",
    "        # rather than by dropping the top-ranked entry: nothing in this\n",
    "        # notebook normalises `centroids`, so the cluster is not guaranteed\n",
    "        # to rank first in its own similarity row, and ties are not ordered\n",
    "        # deterministically by torch.sort.\n",
    "        dists = sim[cluster]\n",
    "        nearby_clusters = torch.nonzero(dists > video_threshold).view(-1)\n",
    "        nearby_clusters = nearby_clusters[nearby_clusters != cluster]\n",
    "        nearby_clusters = nearby_clusters[torch.sort(dists[nearby_clusters], descending=True).indices]\n",
    "        if len(nearby_clusters) == 0:\n",
    "            res.append(None)\n",
    "            continue\n",
    "        if verbose:\n",
    "            print(\"Nearby clusters sim:\", dists[nearby_clusters])\n",
    "            print(\"Nearby clusters idx:\", nearby_clusters)\n",
    "\n",
    "        # Sample one cluster with geometrically decaying rank weights\n",
    "        probabilities = prob_decay ** torch.arange(len(nearby_clusters)).float()\n",
    "        probabilities = probabilities / probabilities.sum()\n",
    "        random_cluster = nearby_clusters[torch.multinomial(probabilities, 1)]\n",
    "        if verbose:\n",
    "            print(f\"Selected cluster sim: {sim[cluster, random_cluster]}\")\n",
    "            print(f\"Selected cluster idx: {random_cluster}\")\n",
    "\n",
    "        # Get all samples from cluster\n",
    "        samples = df[df[\"Speaker_ps\"] == random_cluster.item()]\n",
    "\n",
    "        # Similarity of each of those samples to the selected centroid.\n",
    "        # Assumes df's index labels are row positions in `embeddings` - TODO confirm.\n",
    "        samples_dist = (embeddings[samples.index] @ centroids[random_cluster].T).view(-1)\n",
    "\n",
    "        nearby_samples = torch.nonzero(samples_dist > speaker_threshold).view(-1)\n",
    "        nearby_samples = nearby_samples[torch.sort(samples_dist[nearby_samples], descending=True).indices]\n",
    "        if len(nearby_samples) == 0:\n",
    "            res.append(None)\n",
    "            continue\n",
    "\n",
    "        probabilities = prob_decay ** torch.arange(len(nearby_samples)).float()\n",
    "        probabilities = probabilities / probabilities.sum()\n",
    "        random_sample = nearby_samples[torch.multinomial(probabilities, 1)]\n",
    "        sample = samples.iloc[random_sample.item()]\n",
    "\n",
    "        res.append(sample[\"File\"])\n",
    "\n",
    "        speaker_acc = int(sample[\"Speaker\"] == speaker)\n",
    "        video_acc = int(sample[\"Video\"] == video)\n",
    "\n",
    "        speaker_acc_total += speaker_acc\n",
    "        video_acc_total += video_acc\n",
    "        if verbose:\n",
    "            print(f\"Speaker accuracy: {speaker_acc}\")\n",
    "            print(f\"Video accuracy: {video_acc}\")\n",
    "\n",
    "        count_total += 1\n",
    "\n",
    "    if count_total != 0:\n",
    "        speaker_acc_total /= count_total\n",
    "        video_acc_total /= count_total\n",
    "\n",
    "    # df[\"File2\"] = res\n",
    "\n",
    "    # Guard the ratio so that count=0 yields zero coverage instead of raising.\n",
    "    coverage = count_total / count if count else 0.0\n",
    "\n",
    "    return speaker_acc_total, video_acc_total, coverage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the simulation with fixed thresholds; count, verbose and seed keep\n",
    "# their defaults.\n",
    "simulation(video_threshold=0.835, prob_decay=0.5, speaker_threshold=0.94)"
   ]
  },
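  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Recorded output of a previous `simulation(...)` run, as `(speaker accuracy, video accuracy, coverage)`:\n",
    "\n",
    "(0.9093381686310064, 0.2470534904805077, 0.4412)"
   ]
  },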
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the frame back to the path it was read from. index=False keeps the\n",
    "# row index out of the file; writing it would add one more \"Unnamed: N\"\n",
    "# column every time the notebook reads the CSV and saves it again.\n",
    "df.to_csv(\"../data/voxceleb2_train_ps.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import optuna\n",
    "\n",
    "def objective(trial):\n",
    "    \"\"\"Optuna objective: speaker accuracy returned by `simulation`.\n",
    "\n",
    "    `video_threshold` and `prob_decay` are held fixed; only\n",
    "    `speaker_threshold` is searched, in [0.85, 0.99].\n",
    "    \"\"\"\n",
    "    speaker_acc, video_acc, coverage = simulation(\n",
    "        video_threshold=0.835,#trial.suggest_float('video_threshold', 0.8, 0.85),\n",
    "        prob_decay=0.5, #trial.suggest_float('prob_decay', 0.0, 1.0),\n",
    "        speaker_threshold=trial.suggest_float('speaker_threshold', 0.85, 0.99)\n",
    "    )\n",
    "\n",
    "    score = speaker_acc# + (1 - video_acc)# + coverage\n",
    "\n",
    "    return score\n",
    "\n",
    "# Seed the sampler so the search gives the same trials after a kernel\n",
    "# restart; `simulation` is already seeded, the study was not.\n",
    "study = optuna.create_study(\n",
    "    direction='maximize',\n",
    "    sampler=optuna.samplers.TPESampler(seed=0),\n",
    ")\n",
    "study.optimize(objective, n_trials=25)\n",
    "\n",
    "study.best_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameter dictionaries kept for reference in a bare string literal; the\n",
    "# string is evaluated and discarded, it has no effect on the call below.\n",
    "'''\n",
    "{'video_threshold': 0.6101654597940125,\n",
    " 'prob_decay': 0.03608832601140681,\n",
    " 'speaker_threshold': 0.8391936107587641}\n",
    "\n",
    "{'video_threshold': 0.8251956372897714,\n",
    " 'prob_decay': 0.7986839078081102,\n",
    " 'speaker_threshold': 0.9086315021218292}\n",
    "\n",
    " {'video_threshold': 0.8365454665594554}\n",
    " '''\n",
    "\n",
    "simulation(\n",
    "    video_threshold=0.835, #study.best_params[\"video_threshold\"],\n",
    "    prob_decay=0.5, #study.best_params[\"prob_decay\"],\n",
    "    speaker_threshold=0.94, #study.best_params[\"speaker_threshold\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "\n",
    "# Plot the normalised weights decay_factor ** k for k = 0 .. n - 1.\n",
    "n = 10\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "\n",
    "for decay_factor in [0.5, 1]:\n",
    "    weights = decay_factor ** torch.arange(n).float()\n",
    "    probabilities = (weights / weights.sum()).numpy()\n",
    "    ax.plot(probabilities, marker=\"o\", label=str(decay_factor))\n",
    "    print(probabilities)\n",
    "\n",
    "ax.set_xlabel(\"Index\")\n",
    "ax.set_ylabel(\"Probability\")\n",
    "ax.set_title(\"Exponentially Decreasing Probability Distribution\")\n",
    "ax.grid(True, which=\"both\", linestyle=\"--\", linewidth=0.5)\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
